Skip to content

ggml-webgpu: address precision issues for multimodal #22808

Merged
reeselevine merged 13 commits into
ggml-org:masterfrom
noumena-labs:webgpu/fix-mul-max
May 12, 2026
Merged

ggml-webgpu: address precision issues for multimodal #22808
reeselevine merged 13 commits into
ggml-org:masterfrom
noumena-labs:webgpu/fix-mul-max

Conversation

@Constannnnnt
Copy link
Copy Markdown
Contributor

@Constannnnnt Constannnnnt commented May 7, 2026

Overview

In this PR, I addressed the precision issues for multimodal. More specifically, when mixed types are used in models and projectors, I use f32 for precision in the flash attention (more specifically, in the tile path) for the browser. I did not edit flash_attn.wgsl since subgroup_matrix isn't enabled in my test environment.

Additional information

Inputs:
Tested model: LFM2.5-VL-450M-F16 with F16 mmproj.
Tested images:
test
Tested prompts: Describe this image in detail.


Here is the debugging process to help explain the editings. I calculated the cosine similarity between the embedding layers of the CPU backend and the WebGPU backend.

Without any changes on the master branch,

CLIP Vision Stage Comparison (WebGPU vs. C++ Parity) (Table formated by LLM)

Stage Type Tensor Shape Cosine Similarity Status
Kcur-0 f32 [64, 12, 1024, 1] 1.00000000 ✅ Perfect
Vcur-0 f32 [64, 12, 1024, 1] 0.98777701 ⚠️ Slight Drift
attn_out-0 f32 [768, 1024, 1, 1] 0.99609993 ✅ Pass
ffn_inp-0 f32 [768, 1024, 1, 1] 0.96016181 ❌ Significant Drift
ffn_inp_normed-0 f32 [768, 1024, 1, 1] 0.75001117 ❌ Critical Failure
ffn_out-0 f32 [768, 1024, 1, 1] 0.86092414 ❌ Critical Failure

Results from models: The image shows a variety of people with different styles of clothing and accessories, but no specific details about the individuals. The image is primarily focused on a collection of abstract geometric shapes, including a group of people, that are not clearly defined or detailed. These shapes appear to be the main focus of the image.

From these logs, we can see the cosine similarity discrepancy came from two main computation layers: attention, ffn. Related shaders include flash_attn, binary, norm, unary (like GELU). For example, f16 is more performance but less precise in online softmax operation for flash attention, and we noticed accumulated drifts. Therefore, first step was to use f32 and update the shared memory calculation logics for f32 buffers.

I first started with the flash_attn_tile and vec paths since I started debugging in the browser. And the results for attention layers (attn_out and also Vcur-0) had been increased from 0.98 to 1, which also increased the 1st layer precision (ffn_inp-0) after the attention layer.

CLIP Vision Stage Comparison (Updated Run)

Stage Type Tensor Shape Cosine Similarity Status
Kcur-0 f32 [64, 12, 1024, 1] 1.00000000 ✅ Pass
Vcur-0 f32 [64, 12, 1024, 1] 0.99892639 ✅ Pass
attn_out-0 f32 [768, 1024, 1, 1] 0.99965949 ✅ Pass
ffn_inp-0 f32 [768, 1024, 1, 1] 0.99488728 ✅ Pass
ffn_inp_normed-0 f32 [768, 1024, 1, 1] 0.75360093 ❌ Critical Failure
ffn_out-0 f32 [768, 1024, 1, 1] 0.86667313 ❌ Critical Failure
layer_out-0 f32 [768, 1024, 1, 1] 0.76321508 ❌ Critical Failure

After some debugging and analysis, I then corrected gelu, gelu_quick and gelu_erf functions and used the pytorch implementation GELU — PyTorch 2.11 documentation

CLIP Vision Stage Comparison (Updated Run)

Stage Type Tensor Shape Cosine Similarity Status
Kcur-0 f32 [64, 12, 1024, 1] 1.00000000 ✅ Pass
Vcur-0 f32 [64, 12, 1024, 1] 0.99892639 ✅ Pass
attn_out-0 f32 [768, 1024, 1, 1] 0.99965949 ✅ Pass
ffn_inp-0 f32 [768, 1024, 1, 1] 0.99488728 ✅ Pass
ffn_inp_normed-0 f32 [768, 1024, 1, 1] 0.99764710 ✅ Pass
ffn_out-0 f32 [768, 1024, 1, 1] 0.99951498 ✅ Pass
layer_out-0 f32 [768, 1024, 1, 1] 0.99886686 ✅ Pass

Results: The image features a character from the video game "The Legend of Zelda: Breath of the Wild". The character is depicted in a fantasy setting with a mystical ambiance. The character is standing in front of ancient ruins and surrounded by lush greenery and blue-lit trees, suggesting a serene

Correcting the gelu functions improved the accuracy, but the final result was still incorrect, as we noticed that there were still some small offsets in the Vcur-* layer, and these accumulated drifts caused final errors.

| embedding | f32 | [768, 1024, 1, 1] | 0.74568836 | ❌ Critical Failure |

The final root cause for this issue was that flash_attn_tile.wgsl sized SCORE_REGS_PER_LANE from MAX_SUBGROUP_SIZE, but the browser can run with a smaller runtime subgroup_size. For KV_TILE=64, that can make the tile process only part of K/V. I think this was why we saw some slight offsets on Vcur-* and attn-out. So in this path, the shader now sizes per-lane arrays from MIN_SUBGROUP_SIZE.

Also fixed the flash-attn pipeline cache key to include tile compile constants (q_tile, kv_tile, wg_size, subgroup sizes, SG matrix dims), and tile is now preferred over vec when tile is valid.

Results: The image showcases a detailed digital illustration of a female warrior clad in elaborate, dark-toned armor. She wields a sword with a glowing blue blade, suggesting a supernatural or magical element. The setting is an ancient, possibly mystical, stone structure with columns and arches that frame her figure,

| Vcur-0 | f32 | [64, 12, 1024, 1] | 0.9999999 | ✅ Pass |
...
| layer_out-11 | f32 | [768, 1024, 1, 1] | 0.9999999 | ✅ Pass |
| embedding | f32 | [768, 1024, 1, 1] | 0.99999999 | ✅ Pass |

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes, I used AI to help me with the log analysis and keep clean editing on the files.

@Constannnnnt Constannnnnt requested a review from a team as a code owner May 7, 2026 16:57
@github-actions github-actions Bot added ggml changes relating to the ggml tensor library for machine learning WebGPU labels May 7, 2026
Copy link
Copy Markdown
Contributor

@reeselevine reeselevine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a deep investigation, cool to see how much better the description is with the changes. The PR doesn't say, I'm guessing this is on an NVIDIA GPU?

I do have a couple of high-level comments beyond the minor comments on the code.

One that I think needs to be addressed in this PR:

  • The idea behind the vec path for flash attention is that it should increase performance during typical decode scenarios, when the Q sequence length is only 1. But if I'm understanding correctly, this PR changes to always prefer the tile path over the vec path, which I think means the vec path will basically never run?
    • The way this is supposed to work is that vec should be the priority if sequence length is 1, then subgroup matrix should be preferred to tiling if sequence length is > 1 and subgroup matrices are supported. This PR inverts some of that ordering. I realize the subgroup matrix path in particular may have precision issues, which I address in my point below. But I think we should strive to keep the priority in performance order if at all possible, and make adjustments to the shaders for precision to really try and maintain that order.

One that might not need to be addressed in this PR, but which I think will be important moving forward:

  • The logic around path selection, tile size, and decisions for flash attention has gotten quite complicated, and I think this is at least in part due to the mixing of pipelines for the vec and non-vec pipelines. I think if we split up the logic similar to how it is done for matrix-matrix vs. matrix-vector multiplication, that will make the code much clearer and allow for easier changes moving forward.

One that should not be addressed in this PR but we should think about:

  • Precision of intermediate states clearly can make a large difference depending on the device and model. For example, it's not clear that the subgroup matrix path will even work satisfactorily on many devices, because it computes in f16 precision. It's also not clear that f32 is really needed everywhere, but today you pay the memory overhead/performance cost of it no matter what.
    • Ideally, the precision would be chosen dynamically in a way that maximizes speed and stability across devices, but I realize this is probably a larger (research?) project. But I think it's worth keeping in mind as we make changes.

@ArberSephirotheca for visibility, and also if you have thoughts or comments on this PR since you wrote the vec and tile attention shaders.

0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.70710678));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason the multiplication constant 0.707 is in the call to erf_approx rather than within it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I follow the equation on most of the websites, e.g., https://alaaalatif.github.io/2019-04-11-gelu/, the 1/sqrt(2) is outside of the erf function.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the ci did not pass:

Error while parsing WGSL: :64:74 error: type mismatch for argument 1 in call to 'erf_approx', expected 'f32', got 'f16' let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.70710678));

I will correct this.

uint32_t q_tile;
uint32_t kv_tile;
uint32_t wg_size;
uint32_t min_subgroup_size;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these subgroup and the following sg_mat fields necessary in the key? They are fixed for a given WebGPU device so I don't think they should affect which pipeline is chosen, at least in principle.

Although they might be just proxying whether subgroups/subgroup matrices are are actually supported?

inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;

inline uint32_t ggml_webgpu_effective_min_subgroup_size(const ggml_webgpu_shader_lib_context & context) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these functions needed? If subgroups are supported, I think we can assume that the reported subgroup size is > 0, otherwise that seems like a bug in WebGPU itself.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The initial motivation was to use min subgroup for calculating the per-lane arrays in tile, i.e. line 122-123 in flash_attn_tile.wgsl. because I am not familiar with this, so I did some research about min/max numbers on this topic, and I found this reference on different hardware: https://docs.rs/wgpu/latest/wgpu/struct.AdapterInfo.html. The 0 here is simply for a check. You can notice that these two functions are basically "cross-fallback" logic.

I might overthink this; it seems reasonable to remove them after second thoughts.

@reeselevine reeselevine changed the title ggml-webgpu: address precision issues for multimodel ggml-webgpu: address precision issues for multimodal May 8, 2026
@Constannnnnt
Copy link
Copy Markdown
Contributor Author

Constannnnnt commented May 11, 2026

Thank you for these insights!!

I think I can answer the first question on vec or tile path; actually, before this fix, my workaround entirely relied on the vec path because the vec path did not have the partial kv problem, as it does not need subgroup support (correct me if I am wrong). As discussed with @ArberSephirotheca in #22199, in some of my test cases, the sequence length of a multimodal request is sometimes longer; I thought it would be good to know what happened within the tile path and reverted the order for testing purposes. I did not notice the design of the performance priority, but now I get it. And one question here: why is sequence length is 1 for vec? I am not sure if I misunderstand the sequence length here. Thanks! Oh, and yes, I used an NVIDIA GPU.

@reeselevine
Copy link
Copy Markdown
Contributor

yeah sorry I was misremembering the vec path, it works for sequence lengths (of Q) greater than 1 as well. Hopefully the performance of the vec path is faster in the short sequence cases, if it's not on any of the machines you're testing we might want to revisit it's design :). But thanks for the fix on the tile path, always nice to fix bugs and fix stability issues.

@Constannnnnt
Copy link
Copy Markdown
Contributor Author

I updated the unary shader using the editorconfig, removed the inline functions and redundant pipeline keys, and reverted the flash attn path order.

@reeselevine reeselevine requested review from CISC and ggerganov May 12, 2026 04:18
@reeselevine reeselevine merged commit 239a497 into ggml-org:master May 12, 2026
45 of 46 checks passed
xxmustafacooTR pushed a commit to xxPlayground/llama-cpp-turboquant that referenced this pull request May 12, 2026
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcode v type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove reduant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp, 80 is safer for f32 exp
@Constannnnnt Constannnnnt deleted the webgpu/fix-mul-max branch May 12, 2026 19:12
@ArberSephirotheca
Copy link
Copy Markdown
Contributor

ArberSephirotheca commented May 13, 2026

@Constannnnnt Sorry, I missed this discussion earlier. The changes look great overall! I have one question about the tile-path precision change: after the MIN_SUBGROUP_SIZE fix, did you compare results with kv_shmem staged as f32 versus KV_TYPE?

I’m trying to understand whether f32 staging is still independently required for precision, or whether we can keep kv_shmem in KV_TYPE to reduce workgroup memory usage and avoid the extra f16->f32 staging conversion.

@Constannnnnt
Copy link
Copy Markdown
Contributor Author

@ArberSephirotheca Actually, I am not sure about this, as I did not test this. I can test this by this week and share the results with you.

Jcfunk pushed a commit to Jcfunk/llama.cpp that referenced this pull request May 13, 2026
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcode v type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove reduant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp, 80 is safer for f32 exp
@Constannnnnt
Copy link
Copy Markdown
Contributor Author

Hey @ArberSephirotheca , finally got time to report back. So here are my changes:

var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>;
var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>;
var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>;
...
q_shmem[elem_idx] = select(
            0.0,
            Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale,
            head_q_row < params.seq_len_q);
...
kv_shmem[kv_off + 0u] = KV_TYPE(k4.x);
...
let qv = vec4<Q_TYPE>(
    q_shmem[q_off + 0u],
    q_shmem[q_off + 1u],
    q_shmem[q_off + 2u],
    q_shmem[q_off + 3u]);
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let kv = vec4<KV_TYPE>(
    kv_shmem[kv_off + 0u],
    kv_shmem[kv_off + 1u],
    kv_shmem[kv_off + 2u],
    kv_shmem[kv_off + 3u]);
dot_val += dot(vec4<f32>(qv), vec4<f32>(kv));
...
p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p);
...
kv_shmem[kv_off + 0u] = KV_TYPE(v4.x);
kv_shmem[kv_off + 1u] = KV_TYPE(v4.y);
kv_shmem[kv_off + 2u] = KV_TYPE(v4.z);
kv_shmem[kv_off + 3u] = KV_TYPE(v4.w);
...
let v4 = vec4<KV_TYPE>(
      kv_shmem[kv_off + 0u],
      kv_shmem[kv_off + 1u],
      kv_shmem[kv_off + 2u],
      kv_shmem[kv_off + 3u]);
acc += f32(p) * vec4<f32>(v4);

So basically, use the macro types for most values, and as for the calculation for dot and acc, use f32.

I did not see any accuracy differences between f32 and KV_TYPE directly here.

However, when I set dot_val += f32(dot(vec4<f16>(qv), vec4<f16>(kv))); and `acc += vec4(f16(p) * vec4(v4)); ", The cosine similarity slightly decreases:
cos=0.99999821 vs cos=0.99999707
cos=0.99999874 vs cos=0.99999793
cos=0.99999708 vs cos=0.99999757
cos=0.99999772 vs cos=0.99999564
cos=0.99999851 vs cos=0.99999788
cos=0.99999815 vs cos=0.99999734
cos=0.99999912 vs cos=0.99999836

And results look quite similar (as I did not have a GT for this, so no scores).
In a fantasy-themed image, a female warrior stands in the foreground. She is clad in dark armor that covers her entire body, with intricate designs and patterns visible on the chest plate and skirt. The armor has a metallic sheen, suggesting it is made of steel or a similar material. ...

vs

The image features a female warrior dressed in elaborate, dark armor and holding a sword with a glowing blade. She stands before an ancient stone archway, with trees and a twilight sky in the background, creating a mystical and heroic atmosphere. ...

Let me know if you want to have more details. Thanks.

rsenthilkumar6 pushed a commit to rsenthilkumar6/llama.cpp that referenced this pull request May 19, 2026
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcode v type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove reduant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp, 80 is safer for f32 exp
@reeselevine
Copy link
Copy Markdown
Contributor

If we can keep the KV-cache shared memory as f16 with no meaningful reduction in accuracy, that would be great, since it reduces memory requirements and bandwidth across the board.

@Constannnnnt
Copy link
Copy Markdown
Contributor Author

Yeah, sg. I will create a PR later tonight for this.

baramofme pushed a commit to baramofme/llama-cpp-turboquant that referenced this pull request May 23, 2026
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcode v type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove reduant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp, 80 is safer for f32 exp
winstonma pushed a commit to winstonma/llama.cpp that referenced this pull request May 27, 2026
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcode v type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove reduant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp, 80 is safer for f32 exp
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32

* fix(unary): correct the gelu, gelu quick and gelu erf functions

* fix(flash-attn-tile): fix the hardcode v type

* fix(flash_attn): fix tile path

* fix: pass editorconfig and address the type conflicts

* fix: remove reduant pipeline keys

* fix: remove inline min/max group size functions and revert the flash attn path order

* fix: use clamp to avoid NaN for GELU

* fix: use the right range for exp, 80 is safer for f32 exp
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning WebGPU

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants